Skip to content

ENH: Add native save and read support for SSD#13718

Draft
Aniketsy wants to merge 8 commits intomne-tools:mainfrom
Aniketsy:fix-13328
Draft

ENH: Add native save and read support for SSD#13718
Aniketsy wants to merge 8 commits intomne-tools:mainfrom
Aniketsy:fix-13328

Conversation

@Aniketsy
Copy link
Contributor

@Aniketsy Aniketsy commented Mar 3, 2026

Fixes #13328

@Aniketsy Aniketsy marked this pull request as draft March 3, 2026 17:43
Comment on lines +161 to +175
def _create_cov_callable(self):
"""Recreate covariance callable after initialization or loading."""
self.cov_callable = partial(
_ssd_estimate,
reg=self.reg,
cov_method_params=self.cov_method_params,
info=self.info,
picks=self.picks,
n_fft=self.n_fft,
filt_params_signal=self.filt_params_signal,
filt_params_noise=self.filt_params_noise,
rank=self.rank,
sort_by_spectral_ratio=self.sort_by_spectral_ratio,
)
self.mod_ged_callable = _ssd_mod
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally I would prefer that this method return the partial object and _ssd_mod variable, and leave it to the call site to set the cov_callable and mod_ged_callable attributes. The reason being that, it is not clear for the reader what self._create_cov_callable() is doing, or that it is mutating the class instance. In fact, you could make this a function rather than a method, so that it is clear from the call site what needs to be passed in in order to get the output.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moving this out of __init__ and having this as a method was my suggestion, since we would now also need to create cov_callable for __setstate__, but you're right that its behaviour is obscure.
+1 for returning the functions rather than mutating the instance.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok thanks for clarifying @tsbinns !

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the review, and correction

Copy link
Contributor

@tsbinns tsbinns left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really strong start, thanks @Aniketsy! Have some comments below for suggested next steps. Only skimmed the tests, but will have a proper look once the changes below are addressed.

Just to note, the idea would also be to add this functionality to the other decoding classes SPoC, CSP and XdawnTransformer, but that should be a very simple copy-paste job once the template is set in stone for SSD since they all inheret from the _GEDTransformer parent class.

Comment on lines +103 to +108
filt_params_signal,
filt_params_noise,
filt_params_signal=None,
filt_params_noise=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should stay as not having a default. Having a default None could lead users to assume the class will work without specifying them, but that is not the case. Would be better to leave as having no default, and just pass a placeholder value (e.g., None) when we need to init from a state dict.

Comment on lines +189 to +194
saved_version = state.get("mne_version")
if saved_version is not None and saved_version != _mne_version:
warn(
f"The SSD object was saved with MNE-Python {saved_version} but is "
f"being loaded with {_mne_version}. This may cause issues."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be necessary. If we made changes to the SSD class in a future version, we would try to do so in a way that is backwards-compatible and only give a warning like this if we knew with more certainty that some things might not behave as expected.

state.pop("mod_ged_callable", None)
state["_ssd_state"] = True
state["class_name"] = "SSD"
state["mne_version"] = _mne_version
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See other comment on versioning that we wouldn't normally add this info.

state = self.__dict__.copy()
state.pop("cov_callable", None)
state.pop("mod_ged_callable", None)
state["_ssd_state"] = True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be necessary to add. It should be sufficient for any checks (e.g., when initing from it) that the state dict has all the required attr keys.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it.

Comment on lines +161 to +162
def _create_cov_callable(self):
"""Recreate covariance callable after initialization or loading."""
Copy link
Contributor

@tsbinns tsbinns Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _create_cov_callable(self):
"""Recreate covariance callable after initialization or loading."""
def _create_callables(self):
"""Create covariance callable on initialization or state loading."""

Since this is creating both cov_callable and mod_ged_callable, I think a better name for the method would be something like _create_callables. Also, a suggestion for a nitpicky change to the docstring.

Copy link
Contributor

@tsbinns tsbinns Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider also this alternative suggestion of returning the functions rather than assigning them to the instance within the method: #13718 (comment)

Comment on lines +122 to +136
required_keys = (
"info",
"filt_params_signal",
"filt_params_noise",
"n_components",
"filters_",
"patterns_",
)
missing = [k for k in required_keys if k not in info]
if missing:
raise ValueError(
"If 'info' is a dict, it must be a serialized SSD state "
f"(missing keys: {missing}). "
"Otherwise pass an mne.Info object."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few things:

  • I think rather than clutter __init__, it would be cleaner to house this in __setstate__.
  • I think that all of the expected state dict keys should be considered required, and checked for.
  • For the error message, we don't need users to know that they can instantiate the SSD class from a state dict. Rather, if the info they pass is a dict and it doesn't match the expected state dict format, we should raise a TypeError and tell them they need to pass an Info object. mne.utils.check._validate_type is an easy way to check for this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the pointers.

Comment on lines +215 to +216
with open(fname, "wb") as fid:
pickle.dump(state, fid)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than dumping to a pickle file, it would be good to maintain consistency with how this is handled, e.g., for mne.time_frequency.Spectrum objects (i.e., save in HDF5 format):

def save(self, fname, *, overwrite=False, verbose=None):
"""Save spectrum data to disk (in HDF5 format).
Parameters
----------
fname : path-like
Path of file to save to.
%(overwrite)s
%(verbose)s
See Also
--------
mne.time_frequency.read_spectrum
"""
_, write_hdf5 = _import_h5io_funcs()
check_fname(fname, "spectrum", (".h5", ".hdf5"))
fname = _check_fname(fname, overwrite=overwrite, verbose=verbose)
out = self.__getstate__()
write_hdf5(fname, out, overwrite=overwrite, title="mnepython", slash="replace")

In check_fname, the filetype param could be "ssd" instead of "spectrum".

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See also #13718 (comment)

tsbinns

This comment was marked as duplicate.

tsbinns

This comment was marked as duplicate.

@tsbinns
Copy link
Contributor

tsbinns commented Mar 3, 2026

Ah whoops, was mid-review and didn't see your comments before @scott-huberty!

@Aniketsy
Copy link
Contributor Author

Aniketsy commented Mar 5, 2026

scott-huberty, tsbinns, larsoner thanks for the review and clarification, I’ve addressed most of the points mentioned during the review and am revisiting the changes to ensure I didn’t miss any points.

Just to note, the idea would also be to add this functionality to the other decoding classes SPoC, CSP and XdawnTransformer, but that should be a very simple copy-paste job once the template is set in stone for SSD since they all inheret from the _GEDTransformer parent class.

sure, that sounds good. We can proceed with extending this to the other classes once the SSD implementation is finalized.

@Aniketsy
Copy link
Contributor Author

Aniketsy commented Mar 5, 2026

Ah, I only ran the tests locally. I’ll make sure to build the docs locally as well before pushing the changes in future.

@Aniketsy
Copy link
Contributor Author

Aniketsy commented Mar 9, 2026

@tsbinns I've added verbose --in save() to match other MNE save method to control the overwrite warning log, Please review these changes when you get chance. Thanks!

Copy link
Contributor

@tsbinns tsbinns left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @Aniketsy, really nice work! Just some very minor comments. How would you feel about extending this framework to the other classes?

Comment on lines +482 to +483
if "info" in state and not isinstance(state["info"], Info):
state["info"] = Info(state["info"])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An "info" entry should always be in the state dict. And does the state dict ever contain an instantiated Info object?

I would think it's not needed to nest this in an if statement, just L483 alone should work, no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An "info" entry should always be in the state dict. And does the state dict ever contain an instantiated Info object?

I would think it's not needed to nest this in an if statement, just L483 alone should work, no?

yes, i've just looked into this and I agree, my previous assumptions of removing this was wrong so will update. i'll move the Info reconstruction to __setstate__ rather than read_ssd, following the same pattern as spectrum.__setstate__

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FAILED mne/decoding/tests/test_ssd.py::test_sklearn_compliance[SSD(filt_params_noise={'h_freq':40.0,'l_freq':0.0},filt_params_signal={'h_freq':30.0,'l_freq':0.0},info=100.0)-check_estimators_pickle] - TypeError: mne._fiff.meas_info.Info() argument after ** must be a mapping, not float
FAILED mne/decoding/tests/test_ssd.py::test_sklearn_compliance[SSD(filt_params_noise={'h_freq':40.0,'l_freq':0.0},filt_params_signal={'h_freq':30.0,'l_freq':0.0},info=100.0)-check_estimators_pickle(readonly_memmap=True)] - TypeError: mne._fiff.meas_info.Info() argument after ** must be a mapping, not float

The sklearn pickle test is failing because SSD accepts info as a float , but __setstate__ was unconditionally calling Info(**state["info"]) which crashes on a float. we need to add an isinstance(state["info"], dict) guard tofix our crash.

Comment on lines +474 to +476
See Also
--------
SSD.save
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should add the full path mne.decoding.SSD.save.

Comment on lines +640 to +644
fname_rt = tmp_path / "test_ssd_rt.h5"
ssd.save(fname_rt)
ssd_rt = read_ssd(fname_rt)
assert_array_almost_equal(ssd.filters_, ssd_rt.filters_)
assert_array_almost_equal(ssd.transform(X), ssd_rt.transform(X))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the purpose of this extra check? Is this not already tested?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, yes this is redundant check, I'll remove this.

@Aniketsy
Copy link
Contributor Author

Aniketsy commented Mar 9, 2026

@tsbinns thanks! I've addressed the comment you mentioned, should I extend this framework to other classes now.

"The state may be from an incompatible version of MNE."
)
if state["info"] is not None:
state["info"] = Info(**state["info"])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tsbinns Kept the if guard in CSP, SPoC, and XdawnTransformer because info can be None for these classes, please let me know what you think here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add native save and read functionality to SSD and SPoC objects

4 participants